package modules

import (
	"context"

	"go-zero-admin/apps/system/cmd/rpc/internal/config"
	dbModel "go-zero-admin/apps/system/model"

	"github.com/Masterminds/squirrel"
	"github.com/casbin/casbin/v2"
	"github.com/casbin/casbin/v2/model"
	"github.com/casbin/casbin/v2/persist"
	"github.com/zeromicro/go-zero/core/fx"
	"github.com/zeromicro/go-zero/core/stores/sqlx"
)

// AdapterCasbin implements casbin's storage adapter on top of
// dbModel.SysCasbinRuleModel and also holds the SyncedEnforcer that
// NewCasbinServer builds on top of it.
//
// NOTE(review): Ctx and Session live on the struct because the casbin adapter
// methods take no context argument. Every storage call reads whatever values
// were last set via SetCtxSession. Confirm one instance is not shared between
// concurrent requests that set different values.
type AdapterCasbin struct {
	// Set at construction time by NewCasbinServer.
	config             config.Config
	sysCasbinRuleModel dbModel.SysCasbinRuleModel
	// Assigned by callers (SetCtxSession); used by every storage operation.
	Ctx     context.Context
	Session sqlx.Session
	// Outputs of initPolicy: the enforcer, or the error that prevented building it.
	Enforcer    *casbin.SyncedEnforcer
	EnforcerErr error
}

// NewCasbinServer creates the casbin adapter and its enforcer.
//
// The adapter starts with context.Background() as its context and no session.
// Enforcer construction failures are not returned. They are recorded in the
// EnforcerErr field of the result, and Enforcer may be nil in that case.
func NewCasbinServer(c config.Config, conn sqlx.SqlConn) *AdapterCasbin {
	adapter := new(AdapterCasbin)
	adapter.config = c
	adapter.sysCasbinRuleModel = dbModel.NewSysCasbinRuleModel(conn, c.ModelCache)
	adapter.Ctx = context.Background()

	adapter.initPolicy()
	return adapter
}

// SetCtxSession sets the context and SQL session used by all subsequent
// storage calls made through this adapter. It returns the receiver so the
// call can be chained.
//
// NOTE(review): this mutates the adapter in place. If a single AdapterCasbin
// is shared between concurrent requests, their ctx/session values will race.
func (l *AdapterCasbin) SetCtxSession(ctx context.Context, session sqlx.Session) *AdapterCasbin {
	l.Ctx, l.Session = ctx, session
	return l
}

// initPolicy builds the SyncedEnforcer from the configured model file, with
// this adapter as its policy storage. On success it stores the enforcer in
// l.Enforcer. On failure it stores the error in l.EnforcerErr and leaves
// l.Enforcer unset.
//
// casbin.NewSyncedEnforcer(modelFile, adapter) already loads the full policy
// through the adapter (l.LoadPolicy) and builds the role links before it
// returns. A load failure is reported as its error. A second
// ClearPolicy + l.LoadPolicy(e.GetModel()) here would only repeat the
// database query and bypass Enforcer.LoadPolicy's post-load processing, so it
// is not done.
func (l *AdapterCasbin) initPolicy() {
	e, err := casbin.NewSyncedEnforcer(l.config.SystemCustom.CasBinFile.ModelFile, l)
	if err != nil {
		l.EnforcerErr = err
		return
	}
	l.Enforcer = e
}

// LoadPolicy reads every stored rule row using l.Ctx and adds each one to
// model. It returns the query error, if any. An empty table loads nothing
// and returns nil.
func (l *AdapterCasbin) LoadPolicy(model model.Model) error {
	rules, err := l.sysCasbinRuleModel.FindAll(l.Ctx, l.sysCasbinRuleModel.RowBuilder(), "")
	if err != nil {
		return err
	}
	for _, rule := range rules {
		l.loadPolicyLine(rule, model)
	}
	return nil
}

// SavePolicy writes every rule held in model to the database.
//
// It first calls dropTable and createTable. It then inserts the policy
// section ("p") followed by the grouping section ("g"), one row per rule,
// using l.Ctx and l.Session. It stops and returns the first error
// encountered.
func (l *AdapterCasbin) SavePolicy(model model.Model) (err error) {
	if err = l.dropTable(); err != nil {
		return err
	}
	if err = l.createTable(); err != nil {
		return err
	}

	for _, section := range []string{"p", "g"} {
		for ptype, assertion := range model[section] {
			for _, rule := range assertion.Policy {
				row := l.savePolicyLine(ptype, rule)
				if _, err = l.sysCasbinRuleModel.InsertEx(l.Ctx, l.Session, []*dbModel.SysCasbinRule{row}); err != nil {
					return err
				}
			}
		}
	}
	return nil
}

// AddPolicy inserts a single rule row built from ptype and rule, using
// l.Ctx and l.Session. The sec argument is not used.
func (l *AdapterCasbin) AddPolicy(sec string, ptype string, rule []string) error {
	rows := []*dbModel.SysCasbinRule{l.savePolicyLine(ptype, rule)}
	if _, err := l.sysCasbinRuleModel.InsertEx(l.Ctx, l.Session, rows); err != nil {
		return err
	}
	return nil
}

// RemovePolicy deletes the stored rows matching ptype and rule, delegating
// to rawDelete. As in rawDelete, empty values in rule do not constrain the
// match. The sec argument is not used.
func (l *AdapterCasbin) RemovePolicy(sec string, ptype string, rule []string) error {
	return l.rawDelete(l.savePolicyLine(ptype, rule))
}

// RemoveFilteredPolicy deletes stored rows of type ptype whose columns match
// fieldValues.
//
// fieldValues[0] is compared against column V<fieldIndex>, fieldValues[1]
// against the next column, and so on up to V5. Columns outside that window,
// and empty entries in fieldValues, do not constrain the match (see
// rawDelete). The sec argument is not used.
func (l *AdapterCasbin) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error {
	filter := &dbModel.SysCasbinRule{Ptype: ptype}
	columns := []*string{&filter.V0, &filter.V1, &filter.V2, &filter.V3, &filter.V4, &filter.V5}
	for col, dst := range columns {
		// Position of this column inside fieldValues.
		idx := col - fieldIndex
		if idx >= 0 && idx < len(fieldValues) {
			*dst = fieldValues[idx]
		}
	}
	return l.rawDelete(filter)
}

// loadPolicyLine converts one stored rule row into casbin's text form
// ("ptype, v0, v1, ...") and hands it to persist.LoadPolicyLine, which adds it
// to model.
//
// Only trailing empty columns are dropped. savePolicyLine stores rule[i] in
// column Vi, so an empty column in the middle of a rule is a real (empty)
// field. It is kept as an empty token so the later values stay at their
// saved positions instead of shifting left.
//
// NOTE(review): newer casbin versions return an error from
// persist.LoadPolicyLine. It is not propagated here because this method has
// no error result.
func (l *AdapterCasbin) loadPolicyLine(line *dbModel.SysCasbinRule, model model.Model) {
	values := []string{line.V0, line.V1, line.V2, line.V3, line.V4, line.V5}

	// Trim unused trailing columns only.
	n := len(values)
	for n > 0 && values[n-1] == "" {
		n--
	}

	lineText := line.Ptype
	for _, v := range values[:n] {
		lineText += ", " + v
	}
	persist.LoadPolicyLine(lineText, model)
}

// dropTable deletes every stored rule row, so that SavePolicy replaces the
// persisted policy instead of appending a second copy of it on each call.
//
// The table itself is not dropped. Rows are looked up with FindAll and
// deleted by id with DeleteEx, using the same call shape as rawDelete. The
// delete runs on l.Session. Presumably, when that is a transaction, the
// clear and SavePolicy's inserts commit or roll back together.
func (l *AdapterCasbin) dropTable() (err error) {
	rows, err := l.sysCasbinRuleModel.FindAll(l.Ctx, l.sysCasbinRuleModel.RowBuilder(), "")
	if err != nil {
		return err
	}
	if len(rows) == 0 {
		return nil
	}

	ids := make([]int64, 0, len(rows))
	for _, row := range rows {
		ids = append(ids, row.Id)
	}
	return l.sysCasbinRuleModel.DeleteEx(l.Ctx, l.Session, ids, true, "")
}

// createTable is a no-op that always succeeds. This adapter never creates the
// rule table.
//
// NOTE(review): presumably the table is provisioned by the project's schema
// migrations; confirm.
func (l *AdapterCasbin) createTable() error {
	return nil
}

// savePolicyLine maps a casbin rule onto a storage row. ptype goes to Ptype
// and rule[i] goes to column Vi for i in 0..5. Values past the sixth are
// ignored, and missing values leave their columns empty.
func (l *AdapterCasbin) savePolicyLine(ptype string, rule []string) *dbModel.SysCasbinRule {
	row := &dbModel.SysCasbinRule{Ptype: ptype}
	columns := []*string{&row.V0, &row.V1, &row.V2, &row.V3, &row.V4, &row.V5}
	for i := 0; i < len(rule) && i < len(columns); i++ {
		*columns[i] = rule[i]
	}
	return row
}

// rawDelete deletes every stored row whose ptype equals line.Ptype and whose
// columns equal each non-empty V0..V5 of line. Empty columns in line are
// wildcards.
//
// It first selects the matching rows with l.Ctx and then deletes them by id
// through l.Session. When nothing matches, no delete is issued.
func (l *AdapterCasbin) rawDelete(line *dbModel.SysCasbinRule) error {
	query := l.sysCasbinRuleModel.RowBuilder().Where(squirrel.Eq{"ptype": line.Ptype})

	filters := []struct {
		column string
		value  string
	}{
		{"v0", line.V0},
		{"v1", line.V1},
		{"v2", line.V2},
		{"v3", line.V3},
		{"v4", line.V4},
		{"v5", line.V5},
	}
	for _, f := range filters {
		if f.value == "" {
			continue
		}
		query = query.Where(squirrel.Eq{f.column: f.value})
	}

	rules, err := l.sysCasbinRuleModel.FindAll(l.Ctx, query, "")
	if err != nil {
		return err
	}

	// Collect the ids of the matched rows.
	var ids []int64
	fx.From(func(source chan<- interface{}) {
		for _, rule := range rules {
			source <- rule.Id
		}
	}).ForEach(func(item interface{}) {
		ids = append(ids, item.(int64))
	})

	if len(ids) == 0 {
		return nil
	}
	return l.sysCasbinRuleModel.DeleteEx(l.Ctx, l.Session, ids, true, "")
}
